from tvm import te
import tvm

# Symbolic extents. Each is an independent int32 variable; nothing below
# constrains how they relate to one another.
m, n, k = (te.var(extent_name, "int32") for extent_name in ("m", "n", "k"))

# Three-stage elementwise pipeline: A -> B (+1) -> C (+2) -> D (+3).
# NOTE(review): C iterates over `n` but reads B (extent `m`), and D iterates
# over `k` but reads C (extent `n`). If n > m or k > n, those reads fall
# outside the producer's declared extent. Presumably intentional here, to
# observe how bound inference treats mismatched extents -- confirm before
# reusing this pattern.
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda idx: A[idx] + 1, name="B")
C = te.compute((n,), lambda idx: B[idx] + 2, name="C")
D = te.compute((k,), lambda idx: C[idx] + 3, name="D")

# Build the schedule from the final output D. Attach B's computation at C's
# first (and, since C is 1-D, only) loop axis.
s = te.create_schedule(D.op)
c_axis = C.op.axis[0]
s[B].compute_at(s[C], c_axis)

# Lower with every tensor exposed as an argument, then print the resulting IR.
lowered = tvm.lower(s, [A, B, C, D], simple_mode=True)
print(lowered)